{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OiBSu3YkEcoX"
   },
   "source": [
    "Copyright 2024 DeepMind Technologies Limited.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "   http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y5OeTiryEcoX"
   },
   "source": [
    "# Fine-tuning the 2B Gemma model with flax\n",
    "\n",
    "In this tutorial you will learn how to fine-tune the 2B Gemma model for a simple translation task. To run this colab, you will need to use a TPU v4 runtime."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5m81VQOqEcoX"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "XpSw-_4EEcoY"
   },
   "outputs": [],
   "source": [
    "# @title Installation\n",
    "! pip install git+https://github.com/google-deepmind/gemma.git\n",
    "! pip install --user kaggle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iLafhtv3Rg5F"
   },
   "source": [
    "## Downloading the checkpoint\n",
    "\n",
    "\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
    "\n",
    "1. Visit https://www.kaggle.com/ and create an account.\n",
    "2. Go to your account settings, then the 'API' section.\n",
    "3. Click 'Create new token' to download your key.\n",
    "\n",
    "Then run the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "8q5seOhcUBhx"
   },
   "outputs": [],
   "source": [
    "import kagglehub\n",
    "kagglehub.login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jCZSmEVDVv6O"
   },
   "source": [
    "If everything went well, you should see:\n",
    "```\n",
    "Kaggle credentials set.\n",
    "Kaggle credentials successfully validated.\n",
    "```\n",
    "\n",
    "Now select and download the checkpoint you want to try. On a single host, only the 2b model can fit in memory for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "9PEefz8wEcoY"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}\n",
    "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
    "\n",
    "ckpt_path = os.path.join(weights_dir, variant)\n",
    "vocab_path = os.path.join(weights_dir, 'tokenizer.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "yWaP_LPoEcoY"
   },
   "outputs": [],
   "source": [
    "# @title Python imports\n",
    "\n",
    "import enum\n",
    "import re\n",
    "import string\n",
    "\n",
    "# We import JAX and some related packages.\n",
    "import chex\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import optax\n",
    "\n",
    "# We will use tensorflow to handle the dataset\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "# Finally, we import Gemma.\n",
    "from gemma import params as params_lib\n",
    "from gemma import sampler as sampler_lib\n",
    "from gemma import transformer as transformer_lib\n",
    "import sentencepiece as spm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ejQhgtjbEcoY"
   },
   "source": [
    "## Step 1: prepare the dataset\n",
    "\n",
    "### The MTNT dataset\n",
    "\n",
    "In this tutorial, we will use the MTNT dataset, from the paper [MTNT: A Testbed for Machine Translation of Noisy Text](https://arxiv.org/abs/1809.00388). This dataset is directly available in the [TensorFlow dataset catalog](https://www.tensorflow.org/datasets/catalog/mtnt).\n",
    "\n",
    "More precisely we will focus on the English to French translation.\n",
    "\n",
    "But let's have a look at the data themselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "pg8SfQH0EcoY"
   },
   "outputs": [],
   "source": [
    "ds = tfds.load(\"mtnt/en-fr\", split=\"train\")\n",
    "ds = ds.take(2)\n",
    "ds = ds.as_numpy_iterator()\n",
    "for idx, example in enumerate(ds):\n",
    "  print(f'Example {idx}:')\n",
    "  for key, val in example.items():\n",
    "    print(f'{key}: {val}')\n",
    "  print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aYy4EJDsEcoY"
   },
   "source": [
    "Each sample in the dataset contains two entries:\n",
    "- 'src': The original English sentence.\n",
    "- 'dst': The corresponding French translation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NYC42hJgEcoY"
   },
   "source": [
    "### Tokenizer\n",
    "\n",
    "Let's start by loading our vocabulary base tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "TpyG5YW1EcoY"
   },
   "outputs": [],
   "source": [
    "vocab = spm.SentencePieceProcessor()\n",
    "vocab.Load(vocab_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ab2MSf-qEcoY"
   },
   "source": [
    "Let's customize `SentencePieceProcessor` for our English-to-French translation task. Since we're fine-tuning the English-only Gemma 2B model, we need a few adjustments:\n",
    "\n",
    "- **Input Prefix**: Adding a common prefix to each input signals the translation task. For example we could go with a prompt like `Translate this into French: [INPUT_SENTENCE]`.\n",
    "\n",
    "- **Translation Start suffix**: We add a suffix at the end of each prompt tells the model exactly when to begin the translation process. A new line should do the job.\n",
    "\n",
    "- **LM Tokens**: Gemma models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example."
   ]
  },
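  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the layout described above, the sketch below assembles a prompt and its target with plain Python strings. The `<bos>` and `<eos>` text markers and the helper names `build_source` and `build_target` are placeholders invented for this example; the actual tokenizer works with integer ids, not text markers.\n",
    "\n",
    "```python\n",
    "# Hypothetical text markers standing in for the special token ids.\n",
    "BOS, EOS = '<bos>', '<eos>'\n",
    "PREFIX = 'Translate this into French:\\n'\n",
    "SUFFIX = '\\n'\n",
    "\n",
    "def build_source(sentence):\n",
    "  # Prompt side: BOS, task prefix, input sentence, start-of-translation suffix.\n",
    "  return BOS + PREFIX + sentence + SUFFIX\n",
    "\n",
    "def build_target(translation):\n",
    "  # Target side: the translation followed by EOS so the model learns to stop.\n",
    "  return translation + EOS\n",
    "\n",
    "source = build_source('Hello.')\n",
    "target = build_target('Bonjour.')\n",
    "print(repr(source + target))\n",
    "```\n",
    "\n",
    "The prompt ends right after the newline suffix, so everything the model generates from that point on is read as the translation."
   ]
  },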
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "L9cjK0uxEcoY"
   },
   "outputs": [],
   "source": [
    "class GemmaTokenizer:\n",
    "  \"\"\"Custom wrapper around a SentencePieceProcessor for tensorflow.\"\"\"\n",
    "\n",
    "  def __init__(self,\n",
    "               spm_processor: spm.SentencePieceProcessor):\n",
    "    self._spm_processor = spm_processor\n",
    "\n",
    "  @property\n",
    "  def pad_id(self) -> int:\n",
    "    \"\"\"Fast access to the pad id.\"\"\"\n",
    "    return self._spm_processor.pad_id()\n",
    "\n",
    "  def tokenize(self,\n",
    "               example: str | bytes,\n",
    "               prefix: str = '',\n",
    "               suffix: str = '',\n",
    "               add_eos: bool = True) -> jax.Array:\n",
    "    \"\"\"\n",
    "    Tokenization function.\n",
    "\n",
    "    Args:\n",
    "      example: input string to tokenize.\n",
    "      prefix:  prefix to add to the input string.\n",
    "      suffix:  suffix to add to the input string.\n",
    "      add_eos: if True, add an end of sentence token at the end of the output\n",
    "               sequence.\n",
    "    Returns:\n",
    "      Tokens corresponding to the input string.\n",
    "    \"\"\"\n",
    "    int_list = [self._spm_processor.bos_id()]\n",
    "    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))\n",
    "    if add_eos:\n",
    "      int_list.append(self._spm_processor.eos_id())\n",
    "\n",
    "    return jnp.array(int_list, dtype=jnp.int32)\n",
    "\n",
    "  def tokenize_tf_op(self,\n",
    "                     str_tensor: tf.Tensor,\n",
    "                     prefix: str = '',\n",
    "                     suffix: str = '',\n",
    "                     add_eos: bool = True) -> tf.Tensor:\n",
    "    \"\"\"Tensforflow operator for the tokenize function.\"\"\"\n",
    "    encoded = tf.numpy_function(\n",
    "        self.tokenize,\n",
    "        [str_tensor, prefix, suffix, add_eos],\n",
    "        tf.int32)\n",
    "    encoded.set_shape([None])\n",
    "    return encoded\n",
    "\n",
    "  def to_string(self, tokens: jax.Array) -> str:\n",
    "    \"\"\"Convert an array of tokens to a string.\"\"\"\n",
    "    return self._spm_processor.EncodeIds(tokens.tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6xuCVkurEcoY"
   },
   "source": [
    "Now let's try our custom tokenizer on the MTNT dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "xEA-97ioEcoY"
   },
   "outputs": [],
   "source": [
    "tokenizer = GemmaTokenizer(vocab)\n",
    "\n",
    "def tokenize_source(tokenizer, example: tf.Tensor):\n",
    "  return tokenizer.tokenize_tf_op(example,\n",
    "                                  prefix='Translate this into French:\\n',\n",
    "                                  suffix='\\n',\n",
    "                                  add_eos=False)\n",
    "def tokenize_destination(tokenizer, example: tf.Tensor):\n",
    "  return tokenizer.tokenize_tf_op(example,\n",
    "                                  add_eos=True)\n",
    "\n",
    "ds = tfds.load(\"mtnt/en-fr\",split=\"train\")\n",
    "ds = ds.take(2)\n",
    "ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),\n",
    "                       'dst': tokenize_destination(tokenizer, x['dst'])})\n",
    "ds = ds.as_numpy_iterator()\n",
    "for idx, example in enumerate(ds):\n",
    "  print(f'Example {idx}:')\n",
    "  for key, val in example.items():\n",
    "    print(f'{key}: {val}')\n",
    "  print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r-x0aTugEcoY"
   },
   "source": [
    "### Data loader\n",
    "\n",
    "We can now wrap everything a build our data loader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "XwFFs2mDEcoY"
   },
   "outputs": [],
   "source": [
    "@chex.dataclass(frozen=True)\n",
    "class TrainingInput:\n",
    "  # Input tokens given to the model\n",
    "  input_tokens: jax.Array\n",
    "\n",
    "  # A mask that determines which tokens contribute to the target loss\n",
    "  # calculation.\n",
    "  target_mask: jax.Array\n",
    "\n",
    "class DatasetSplit(enum.Enum):\n",
    "  TRAIN = 'train'\n",
    "  VALIDATION = 'valid'\n",
    "\n",
    "\n",
    "class MTNTDatasetBuilder:\n",
    "  \"\"\"Data loader for the MTNT dataset.\"\"\"\n",
    "\n",
    "  N_ITEMS = {DatasetSplit.TRAIN: 35_692,\n",
    "             DatasetSplit.VALIDATION: 811}\n",
    "\n",
    "  BUFFER_SIZE_SHUFFLE = 10_000\n",
    "  TRANSLATION_PREFIX = 'Translate this into French:\\n'\n",
    "  TRANSLATION_SUFFIX = '\\n'\n",
    "\n",
    "  def __init__(self,\n",
    "               tokenizer : GemmaTokenizer,\n",
    "               max_seq_len: int):\n",
    "    \"\"\"Constructor.\n",
    "\n",
    "    Args:\n",
    "      tokenizer: Gemma tokenizer to use.\n",
    "      max_seq_len: size of each sequence in a given batch.\n",
    "    \"\"\"\n",
    "    self._tokenizer = tokenizer\n",
    "    self._base_data = {\n",
    "        DatasetSplit.TRAIN: tfds.load(\"mtnt/en-fr\",split=\"train\"),\n",
    "        DatasetSplit.VALIDATION: tfds.load(\"mtnt/en-fr\",split=\"valid\"),\n",
    "    }\n",
    "    self._max_seq_len = max_seq_len\n",
    "\n",
    "  def _tokenize_source(self, example: tf.Tensor):\n",
    "    \"\"\"Tokenization function for the source.\"\"\"\n",
    "    return self._tokenizer.tokenize_tf_op(example,\n",
    "                                          prefix=self.TRANSLATION_PREFIX,\n",
    "                                          suffix=self.TRANSLATION_SUFFIX,\n",
    "                                          add_eos=False)\n",
    "\n",
    "  def _tokenize_destination(self, example: tf.Tensor):\n",
    "    \"\"\"Tokenization function for the French translation.\"\"\"\n",
    "    return self._tokenizer.tokenize_tf_op(example,\n",
    "                                          add_eos=True)\n",
    "\n",
    "  def _pad_up_to_max_len(self,\n",
    "                         input_tensor: tf.Tensor,\n",
    "                         pad_value: int | bool,\n",
    "                         ) -> tf.Tensor:\n",
    "    \"\"\"Pad the given tensor up to sequence length of a batch.\"\"\"\n",
    "    seq_len = tf.shape(input_tensor)[0]\n",
    "    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)\n",
    "    return tf.pad(input_tensor,\n",
    "                  [[0, to_pad]],\n",
    "                  mode='CONSTANT',\n",
    "                  constant_values=pad_value,\n",
    "                  )\n",
    "\n",
    "  def _to_training_input(self,\n",
    "                         src_tokens: jax.Array,\n",
    "                         dst_tokens: jax.Array,\n",
    "                         ) -> TrainingInput:\n",
    "    \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n",
    "\n",
    "    # The input sequence fed to the model is simply the concatenation of the\n",
    "    # source and the destination.\n",
    "    tokens = tf.concat([src_tokens, dst_tokens], axis=0)\n",
    "\n",
    "    # We want to prevent the model from updating based on the source (input)\n",
    "    # tokens. To achieve this, we add a target mask to each input.\n",
    "    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)\n",
    "    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)\n",
    "    mask = tf.concat([q_mask, a_mask], axis=0)\n",
    "\n",
    "    # If the output tokens sequence is smaller than the target sequence size,\n",
    "    # then we pad it with pad tokens.\n",
    "    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n",
    "\n",
    "    # We don't want to perform the backward on the pad tokens.\n",
    "    mask = self._pad_up_to_max_len(mask, False)\n",
    "\n",
    "    return TrainingInput(input_tokens=tokens, target_mask=mask)\n",
    "\n",
    "\n",
    "  def get_train_dataset(self, batch_size: int, num_epochs: int):\n",
    "    \"\"\"Build the training dataset.\"\"\"\n",
    "\n",
    "    # Tokenize each sample\n",
    "    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),\n",
    "                                                             self._tokenize_destination(x['dst'])))\n",
    "\n",
    "    # Convert them to training inputs\n",
    "    ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
    "\n",
    "    # Remove the samples which are too long\n",
    "    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
    "\n",
    "    # Shuffle the dataset\n",
    "    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n",
    "\n",
    "    # Repeat if necessary\n",
    "    ds = ds.repeat(num_epochs)\n",
    "\n",
    "    # Build batches\n",
    "    ds = ds.batch(batch_size, drop_remainder=True)\n",
    "    return ds\n",
    "\n",
    "  def get_validation_dataset(self, batch_size: int):\n",
    "    \"\"\"Build the validation dataset.\"\"\"\n",
    "\n",
    "    # Same as the training dataset, but no shuffling and no repetition\n",
    "    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),\n",
    "                                                                  self._tokenize_destination(x['dst'])))\n",
    "    ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
    "    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
    "    ds = ds.batch(batch_size, drop_remainder=True)\n",
    "    return ds"
   ]
  },
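  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the masking and padding logic of the data loader above concrete, here is a minimal pure-Python sketch of the same idea on plain lists. The function name `build_training_input`, the token ids, and `pad_id=0` are made up for illustration and are not taken from the Gemma vocabulary.\n",
    "\n",
    "```python\n",
    "def build_training_input(src_tokens, dst_tokens, max_seq_len, pad_id=0):\n",
    "  # Concatenate source and destination tokens into a single sequence.\n",
    "  tokens = list(src_tokens) + list(dst_tokens)\n",
    "  # Only the destination tokens contribute to the loss.\n",
    "  mask = [False] * len(src_tokens) + [True] * len(dst_tokens)\n",
    "  # Pad up to max_seq_len; padded positions never contribute to the loss.\n",
    "  to_pad = max(max_seq_len - len(tokens), 0)\n",
    "  return tokens + [pad_id] * to_pad, mask + [False] * to_pad\n",
    "\n",
    "tokens, mask = build_training_input([2, 11, 12], [2, 21, 1], max_seq_len=8)\n",
    "print(tokens)\n",
    "print(mask)\n",
    "```\n",
    "\n",
    "Sequences longer than `max_seq_len` are left unpadded by this sketch, which mirrors why the data loader filters them out afterwards."
   ]
  },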
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_Sq9uC15EcoZ"
   },
   "source": [
    "Let's give it a try."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "bYeduOaNEcoZ"
   },
   "outputs": [],
   "source": [
    "tokenizer = GemmaTokenizer(vocab)\n",
    "dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)\n",
    "ds = dataset_builder.get_train_dataset(3, 1)\n",
    "ds = ds.take(2)\n",
    "ds = ds.as_numpy_iterator()\n",
    "for idx, example in enumerate(ds):\n",
    "  print(f'Example {idx}:')\n",
    "  for key, val in example.items():\n",
    "    print(f'{key}: {val}')\n",
    "  print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_VsT2o6JEcoZ"
   },
   "source": [
    "## Fine tuning the Gemma model\n",
    "\n",
    "### Getting started\n",
    "\n",
    "First let's load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "VDlfziQVEcoZ"
   },
   "outputs": [],
   "source": [
    "# Load parameters\n",
    "\n",
    "# TODO: change once the downloading url is known\n",
    "params = params_lib.load_and_format_params(ckpt_path)\n",
    "\n",
    "# We use the `transformer_lib.TransformerConfig.from_params` function to\n",
    "# automatically load the correct configuration from a checkpoint. Note that the\n",
    "# vocabulary size is smaller than the number of input embeddings due to unused\n",
    "# tokens in this release.\n",
    "config_2b = transformer_lib.TransformerConfig.from_params(\n",
    "    params,\n",
    "    cache_size=30  # Number of time steps in the transformer's cache\n",
    ")\n",
    "model_2b = transformer_lib.Transformer(config=config_2b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cGbfx6XVEcoZ"
   },
   "source": [
    "Can our model translate French ? Well let's try it out !"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "jWr6Sea_EcoZ"
   },
   "outputs": [],
   "source": [
    "sampler_old = sampler_lib.Sampler(\n",
    "    transformer=model_2b,\n",
    "    vocab=vocab,\n",
    "    params=params['transformer'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "S6937NTjEcoZ"
   },
   "outputs": [],
   "source": [
    "print(sampler_old(\n",
    "    [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n",
    "    # number of steps performed when generating\n",
    "    total_generation_steps=30,\n",
    "  ).text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0Z0CXW4REcoZ"
   },
   "source": [
    "As expected, it didn't work. Let's see if we can get better results by fine-tuning.\n",
    "\n",
    "Before moving further, don't forget to clear the memory if necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "LbJa4S5WEcoZ"
   },
   "outputs": [],
   "source": [
    "del sampler_old"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gxf6gVGCEcoZ"
   },
   "source": [
    "### Model forward and loss function\n",
    "\n",
    "Gemma `Transformer` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:\n",
    "\n",
    "- `init`: Initializes the model's parameters.\n",
    "\n",
    "- `apply`: Executes the model's `__call__` function using a given set of parameters.\n",
    "\n",
    "Since are working with pre-trained weights, we won't use the `init` function.\n",
    "\n",
    "We define a `forward_and_loss_fn` as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "iEcV0XEEEcoZ"
   },
   "outputs": [],
   "source": [
    "def forward_and_loss_fn(params,\n",
    "                        *,\n",
    "                        model: transformer_lib.Transformer,\n",
    "                        input_tokens: jax.Array,            # Shape [B, L]\n",
    "                        input_mask: jax.Array,              # Shape [B, L]\n",
    "                        positions: jax.Array,               # Shape [B, L]\n",
    "                        attention_mask: jax.Array,          # [B, L, L]\n",
    "                        ) -> jax.Array:\n",
    "  \"\"\"Foward pass and loss function.\n",
    "\n",
    "  Args:\n",
    "    params: model's input parameters.\n",
    "    model: gemma transformer model to call.\n",
    "    input_tokens: input tokens sequence, shape [B, L].\n",
    "    input_mask: tokens to ignore when computing the loss, shape [B, L].\n",
    "    positions: relative position of each token, shape [B, L].\n",
    "    attention_mask: input attention mask, shape [B, L].\n",
    "\n",
    "  Returns:\n",
    "    Softmax cross-entropy loss for the next-token prediction task.\n",
    "  \"\"\"\n",
    "\n",
    "  # Foward pass on the input data.\n",
    "  # No attention cache is needed here.\n",
    "  logits, _ = model.apply(\n",
    "        params,\n",
    "        input_tokens,\n",
    "        positions,\n",
    "        None,              # Attention cache is None.\n",
    "        attention_mask,\n",
    "    )\n",
    "\n",
    "  # Exclude the last step as it does not appear in the targets.\n",
    "  logits = logits[0, :-1]\n",
    "\n",
    "  # Similarly, the first token cannot be predicteds.\n",
    "  target_tokens = input_tokens[0, 1:]\n",
    "  target_mask = input_mask[0, 1:]\n",
    "\n",
    "  # Convert the target labels into one-hot encoded vectors.\n",
    "  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n",
    "\n",
    "  # Don't update on unwanted tokens.\n",
    "  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]\n",
    "\n",
    "  # Normalisation factor.\n",
    "  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)\n",
    "\n",
    "  # Return the nll loss.\n",
    "  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor"
   ]
  },
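  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shifting and masking in the loss above can be checked by hand on a tiny case. The sketch below is a pure-Python version for a single sequence; the name `masked_next_token_loss` is hypothetical, and it divides by the number of unmasked targets instead of adding a small epsilon, which is a simplification of the cell above.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def masked_next_token_loss(logits, tokens, mask):\n",
    "  # logits[t] scores the token at position t + 1, so the last row of logits\n",
    "  # is unused and the first token is never a target.\n",
    "  total, count = 0.0, 0\n",
    "  for t in range(len(tokens) - 1):\n",
    "    if not mask[t + 1]:\n",
    "      continue\n",
    "    row = logits[t]\n",
    "    # Numerically stable log of the softmax normaliser.\n",
    "    m = max(row)\n",
    "    log_z = m + math.log(sum(math.exp(x - m) for x in row))\n",
    "    # Negative log-likelihood of the target token.\n",
    "    total -= row[tokens[t + 1]] - log_z\n",
    "    count += 1\n",
    "  return total / max(count, 1)\n",
    "\n",
    "# Uniform logits over a vocabulary of 4: each unmasked target costs log(4).\n",
    "logits = [[0.0] * 4 for _ in range(3)]\n",
    "loss = masked_next_token_loss(logits, tokens=[0, 1, 2], mask=[False, True, True])\n",
    "print(loss)\n",
    "```\n",
    "\n",
    "With uniform logits the model is maximally unsure, so the loss equals the log of the vocabulary size, which is a handy sanity check for any implementation."
   ]
  },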
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y83DimpjEcoZ"
   },
   "source": [
    "The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "cbWfdHf0EcoZ"
   },
   "outputs": [],
   "source": [
    "def get_attention_mask_and_positions(example: jax.Array,\n",
    "                                     pad_id : int,\n",
    "                                     )-> tuple[jax.Array, jax.Array]:\n",
    "  \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n",
    "  pad_mask = example != pad_id\n",
    "  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n",
    "  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n",
    "  return current_token_position, attention_mask"
   ]
  },
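  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the sketch below builds a position vector and a causal attention mask for one padded sequence in plain Python. It is an illustration of the general idea under the assumption that `pad_id` is 0, not the implementation used by `transformer_lib`, and the name `positions_and_causal_mask` is invented for this example.\n",
    "\n",
    "```python\n",
    "def positions_and_causal_mask(tokens, pad_id=0):\n",
    "  # True where the token is real, False where it is padding.\n",
    "  valid = [t != pad_id for t in tokens]\n",
    "  # Each position counts the real tokens seen so far, starting at 0.\n",
    "  positions, seen = [], 0\n",
    "  for v in valid:\n",
    "    seen += int(v)\n",
    "    positions.append(max(seen - 1, 0))\n",
    "  # Token i may attend to token j only if j is not padding and j <= i.\n",
    "  n = len(tokens)\n",
    "  attn = [[(j <= i) and valid[j] for j in range(n)] for i in range(n)]\n",
    "  return positions, attn\n",
    "\n",
    "positions, attn = positions_and_causal_mask([5, 6, 7, 0])\n",
    "print(positions)\n",
    "for row in attn:\n",
    "  print([int(x) for x in row])\n",
    "```\n",
    "\n",
    "The mask is lower triangular over the real tokens, and the column of the padding token is zero everywhere, so no position ever attends to padding."
   ]
  },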
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xbxYMMWLEcoZ"
   },
   "source": [
    "We can now build the train_step function which performs the backward pass and updates the model's parameters accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "cPSfp7ZUEcoZ"
   },
   "outputs": [],
   "source": [
    "def train_step(model: transformer_lib.Transformer,\n",
    "               params,\n",
    "               optimizer: optax.GradientTransformation,\n",
    "               opt_state: optax.OptState,\n",
    "               pad_id: int,\n",
    "               example: TrainingInput):\n",
    "  \"\"\"Train step.\n",
    "\n",
    "  Args:\n",
    "    model: gemma transformer model.\n",
    "    params: model's input parameters.\n",
    "    optimizer: optax optimizer to use.\n",
    "    opt_state: input optimizer's state.\n",
    "    pad_id: id of the pad token.\n",
    "    example: input batch.\n",
    "\n",
    "  Returns:\n",
    "    Training loss, updated parameters, updated optimizer state.\n",
    "  \"\"\"\n",
    "\n",
    "  # Build the position and attention mask vectors.\n",
    "  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n",
    "\n",
    "  # Forward and backward passes\n",
    "  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,\n",
    "                                                             model=model,\n",
    "                                                             input_tokens=example.input_tokens,\n",
    "                                                             input_mask=example.target_mask,\n",
    "                                                             positions=positions,\n",
    "                                                             attention_mask=attention_mask)\n",
    "  # Update the parameters\n",
    "  updates, opt_state = optimizer.update(grads, opt_state)\n",
    "  params = optax.apply_updates(params, updates)\n",
    "\n",
    "  return train_loss, params, opt_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R2QXp116EcoZ"
   },
   "source": [
    "Similarly, we build a `validation_step` function without backward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "yU4oR92YEcoa"
   },
   "outputs": [],
   "source": [
    "def validation_step(model: transformer_lib.Transformer,\n",
    "                    params,\n",
    "                    pad_id: int,\n",
    "                    example: TrainingInput,\n",
    "                    ):\n",
    "  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n",
    "  val_loss = forward_and_loss_fn(params,\n",
    "                                 model=model,\n",
    "                                 input_tokens=example.input_tokens,\n",
    "                                 input_mask=example.target_mask,\n",
    "                                 positions=positions,\n",
    "                                 attention_mask=attention_mask)\n",
    "  return val_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6g6LFWJbEcoa"
   },
   "source": [
    "And now the training loop itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "xT4bAqNLEcoa"
   },
   "outputs": [],
   "source": [
    "@chex.dataclass(frozen=True)\n",
    "class TrainingConfig:\n",
    "  learning_rate: float\n",
    "  num_epochs: int\n",
    "  eval_every_n: int\n",
    "  batch_size: int\n",
    "  max_steps: int | None = None\n",
    "\n",
    "\n",
    "def train_loop(\n",
    "    model: transformer_lib.Transformer,\n",
    "    params,\n",
    "    dataset_builder: MTNTDatasetBuilder,\n",
    "    training_cfg: TrainingConfig):\n",
    "\n",
    "\n",
    "  # We jit the train step, making the whole loop much more efficient\n",
    "  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])\n",
    "\n",
    "  # We do the same with the validation step\n",
    "  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])\n",
    "\n",
    "  # To save memory, we use a SGD optimizer instead of the usual Adam. Note that\n",
    "  # for this specific example SGD is more than enough.\n",
    "  optimizer = optax.sgd(training_cfg.learning_rate)\n",
    "  opt_state = optimizer.init(params)\n",
    "\n",
    "  # Build the training dataset\n",
    "  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,\n",
    "                                               num_epochs=training_cfg.num_epochs)\n",
    "  train_ds = train_ds.as_numpy_iterator()\n",
    "\n",
    "  # Build the validation dataset, with a limited number of samples for this demo\n",
    "  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)\n",
    "  validation_ds = validation_ds.take(50)\n",
    "\n",
    "  n_steps = 0\n",
    "  avg_loss=0\n",
    "\n",
    "  # A first round of validation loss\n",
    "  n_steps_eval = 0\n",
    "  eval_loss = 0\n",
    "  val_iterator = validation_ds.as_numpy_iterator()\n",
    "  for val_example in val_iterator:\n",
    "    eval_loss += compiled_validation_step(model,\n",
    "                                          params,\n",
    "                                          dataset_builder._tokenizer.pad_id,\n",
    "                                          val_example)\n",
    "    n_steps_eval += 1\n",
    "  print(f\"Start, validation loss: {eval_loss/n_steps_eval}\")\n",
    "\n",
    "  for train_example in train_ds:\n",
    "    train_loss, params, opt_state = compiled_train_step(model=model,\n",
    "                                                        params=params,\n",
    "                                                        optimizer=optimizer,\n",
    "                                                        opt_state=opt_state,\n",
    "                                                        pad_id=dataset_builder._tokenizer.pad_id,\n",
    "                                                        example=train_example)\n",
    "    n_steps += 1\n",
    "    avg_loss += train_loss\n",
    "    if n_steps % training_cfg.eval_every_n == 0:\n",
    "      eval_loss = 0\n",
    "\n",
    "      n_steps_eval = 0\n",
    "      val_iterator = validation_ds.as_numpy_iterator()\n",
    "      for val_example in val_iterator:\n",
    "        eval_loss += compiled_validation_step(model,\n",
    "                                              params,\n",
    "                                              dataset_builder._tokenizer.pad_id,\n",
    "                                              val_example)\n",
    "        n_steps_eval +=1\n",
    "      avg_loss /= training_cfg.eval_every_n\n",
    "      eval_loss /= n_steps_eval\n",
    "      print(f\"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}\")\n",
    "      avg_loss=0\n",
    "    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:\n",
    "      break\n",
    "  return params"
   ]
  },
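  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above relies on plain SGD, which keeps no per-parameter state and is why it is cheaper in memory than Adam. As a minimal sketch of the update rule on a toy parameter dictionary (the helper name `sgd_step` and the numbers are invented for illustration, and this is not how `optax` is implemented internally):\n",
    "\n",
    "```python\n",
    "def sgd_step(params, grads, learning_rate):\n",
    "  # Move every parameter against its gradient, scaled by the learning rate.\n",
    "  return {name: value - learning_rate * grads[name]\n",
    "          for name, value in params.items()}\n",
    "\n",
    "toy_params = {'w': 1.0, 'b': -0.5}\n",
    "toy_grads = {'w': 2.0, 'b': -1.0}\n",
    "new_params = sgd_step(toy_params, toy_grads, learning_rate=0.1)\n",
    "print(new_params)\n",
    "```\n",
    "\n",
    "Adam would additionally track running averages of the gradient and of its square for every parameter, roughly tripling the memory tied to the weights."
   ]
  },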
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "muwkf_ZgEcoa"
   },
   "source": [
    "We can fine-tune our model on a limited number of steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "7SL2VAmVEcoa"
   },
   "outputs": [],
   "source": [
    "# Small seq size so that everything fits in memory\n",
    "SEQ_SIZE = 25\n",
    "tokenizer = GemmaTokenizer(vocab)\n",
    "dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)\n",
    "training_cfg = TrainingConfig(learning_rate=1e-4,\n",
    "                              num_epochs=1,\n",
    "                              eval_every_n=20,\n",
    "                              batch_size=1,\n",
    "                              max_steps=100)\n",
    "\n",
    "params = train_loop(model=model_2b,\n",
    "                    params={'params': params['transformer']},\n",
    "                    dataset_builder=dataset_builder,\n",
    "                    training_cfg=training_cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "abChlybFEcod"
   },
   "source": [
    "Both the training loss and the validation's are going down. But is it working ? Let's try again with our previous example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "dQ1oCF10Ecod"
   },
   "outputs": [],
   "source": [
    "sampler = sampler_lib.Sampler(\n",
    "    transformer=model_2b,\n",
    "    vocab=vocab,\n",
    "    params=params['params'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fIwhAvMsEcod"
   },
   "source": [
    "To ensure our input matches the training format, remember to use the prefix 'Translate this into French:\\n'  and a newline character at the end. This signals the model to begin translation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "S5F3fk22Ecod"
   },
   "outputs": [],
   "source": [
    "sampler(\n",
    "    [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n",
    "    total_generation_steps=30,\n",
    "    ).text\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "private_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
